Enabling word-level timestamps for Wav2Vec 2.0 #3627
Conversation
thanks - could you also add it to decoders here, which are soon to replace the old ones you improved: https://github.com/pytorch/fairseq/tree/master/examples/speech_recognition/new/decoders
Done! By the way, I only added this to the KenLM decoder. Do you think the same approach would work for FairseqLMDecoder?
yes, it should work for all decoders
@alexeib has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Hey @Nithin-Holla, how do I interpret these timesteps? Are the timesteps in seconds?
@harveenchadha These timesteps are similar to the ones returned by the ctcdecode library. They indicate the frame number corresponding to each of the predicted characters. If you have the start time and duration of an audio segment, one simple way to convert these timesteps into seconds is `segment_start + timestep/total_frames * segment_duration`.
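A minimal sketch of that conversion (assuming `hypo` is one hypothesis dict returned by the modified decoders below, and that `segment_start`, `segment_duration` (both in seconds) and `total_frames` are known for the audio segment):
def timesteps_to_seconds(hypo, segment_start, segment_duration, total_frames):
    # One time value (in seconds) per non-blank character in hypo["symbols"].
    return [segment_start + t / total_frames * segment_duration
            for t in hypo["timesteps"]]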
Hey @Nithin-Holla, does this also work for the Viterbi decoder?
@hpaliwal1225 I haven't checked it for the Viterbi decoder. But if the expression above applies, the same conversion should work.
@Nithin-Holla, thanks for this awesome contribution. Could you share a working code example? I've been trying to work through some of the code in the speech recognition page but realize I'm in over my head. In particular, I'm not sure why I need to download and pre-process the entire LibriSpeech corpus to run inference on a tiny wav file. If you could share any code/example for running inference (with timings) on a small wav file, that'd be really helpful. Thanks much!
This `segment_start + timestep/total_frames * segment_duration` only works when the file duration is less than 60s. What should I do if the file is longer than 60s?
Hi guys, could anyone please explain to me where I can get the values for `segment_start`, `total_frames`, and `segment_duration`?
I figured out how to get the word-level timestamps from the fairseq decoders. Here are my modified `get_timesteps`, `get_symbols`, and `decode` methods:
def get_timesteps(self, token_idxs: List[int]) -> List[int]:
"""Returns frame numbers corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of decoded tokens (including blank tokens), i.e. list of tokens spanning all frames of the emission matrix.
Returns
-------
List[int]
Frame numbers corresponding to every non-blank token.
"""
timesteps = []
for i, token_idx in enumerate(token_idxs):
if token_idx == self.blank:
continue
if i == 0 or token_idx != token_idxs[i-1]:
timesteps.append(i)
return timesteps
def get_symbols(self, token_idxs: List[int]) -> List[str]:
"""Returns characters corresponding to every non-blank token.
Parameters
----------
token_idxs : List[int]
IDs of non-blank tokens.
Returns
-------
List[str]
Characters corresponding to every non-blank token.
"""
chars = []
for token_idx in token_idxs:
chars.append(self.symbols[token_idx])
return chars
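# Toy illustration (not part of the original code): with self.blank = 0 and
# token_idxs = [0, 5, 5, 0, 7, 7, 0], get_timesteps returns [1, 4] -- the frames
# where each new non-blank token starts -- while get_symbols([5, 7]) looks up the
# corresponding characters in self.symbols.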
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
if self.asg_transitions is None:
transitions = torch.FloatTensor(N, N).zero_()
else:
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
viterbi_path = torch.IntTensor(B, T)
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
CpuViterbiPath.compute(
B,
T,
N,
get_data_ptr_as_bytes(emissions),
get_data_ptr_as_bytes(transitions),
get_data_ptr_as_bytes(viterbi_path),
get_data_ptr_as_bytes(workspace),
)
for b in range(B):
tokens = self.get_tokens(viterbi_path[b].tolist()).tolist()
hypos.append(
[
{
"tokens": tokens, # non-blank token idxs.
"symbols": self.get_symbols(
tokens
), # characters (symbols) corresponding to non-blank token idxs.
"score": 0,
"timesteps": self.get_timesteps(
viterbi_path[b].tolist()
), # frame numbers of non-blank tokens.
"words": post_process(
self.tgt_dict.string(tokens), "letter"
).split(
" "
), # the transcript as a list of words.
}
]
)
return hypos
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append(
[
{
"tokens": tokens, # non-blank token idxs.
"symbols": self.get_symbols(
tokens
), # characters (symbols) corresponding to non-blank token idxs.
"score": result.score,
"timesteps": self.get_timesteps(
result.tokens
), # frame numbers of non-blank tokens.
"words": [
self.word_dict.get_entry(x) for x in result.words if x >= 0
], # the transcript as a list of words. Empty if lexicon-free decoding.
}
for result in nbest_results
if (
tokens := self.get_tokens(result.tokens).tolist()
) # tokens is a local variable for the list comprehension.
]
)
return hypos
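# Usage sketch (not from the original PR): each modified decoder exposes
# decode(emissions) with emissions of shape (B, T, N), returning one list of
# hypothesis dicts per utterance. For example, with the KenLM decoder above,
# characters can be paired with their frame numbers like this:
#   best = decoder.decode(emissions)[0][0]
#   char_frames = list(zip(best["symbols"], best["timesteps"]))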
def decode(self, emissions):
B, T, N = emissions.size()
hypos = []
def idx_to_word(idx):
if self.unit_lm:
return self.idx_to_wrd[idx]
else:
return self.word_dict[idx]
def make_hypo(result):
hypo = {
"tokens": self.get_tokens(result.tokens).tolist(), # non-blank token idxs.
"score": result.score
}
if self.lexicon:
hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0] # the transcript as a list of words.
hypo["symbols"] = self.get_symbols(hypo["tokens"]) # characters (symbols) corresponding to non-blank token idxs.
hypo["timesteps"] = self.get_timesteps(result.tokens) # frame numbers of non-blank tokens.
return hypo
for b in range(B):
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
results = self.decoder.decode(emissions_ptr, T, N)
nbest_results = results[: self.nbest]
hypos.append([make_hypo(result) for result in nbest_results])
self.lm.empty_cache()
return hypos
I then postprocess the results in my own custom script to get the word-level time alignments (in seconds) for each hypothesis:
def beam_search_decode_fairseq(hypos, emission_mx, audio_lens, num_hyps, time_aligns):
"""Process the results of a W2lDecoder object from fairseq.
Args:
hypos (Union[List[Dict], List[List[Dict]]]):
List of results for each audio file returned by a W2lDecoder object. If the number of hypotheses to return (W2lDecoder.nbest) is 1, hypos is a list containing the best hypothesis dict for each audio file.
If W2lDecoder.nbest > 1, hypos is a list of lists, with the N best hypothesis dicts for each audio file.
emission_mx (torch.Tensor of shape (B, T, N)):
The batched emission matrix outputted by the w2v2 acoustic model trained in fairseq.
audio_lens (List[int]):
The lengths of the original audio files in the batch, measured in number of samples.
num_hyps (int):
The number of best hypotheses to return per audio file.
time_aligns (bool):
Flag used to specify whether to calculate word-level time alignment in seconds for each hypothesis.
Returns:
transcripts (Union[List[Dict], List[List[Dict]]]):
List of processed results for each audio file. If W2lDecoder.nbest = 1, transcripts is a list containing the best hypothesis dict for each audio file.
If W2lDecoder.nbest > 1, transcripts is a list of lists, with the N best hypothesis dicts for each audio file.
A hypothesis dict has the following fields:
'pred_txt': (str) the transcript hypothesis itself.
'timestamps_word': (List[Dict]) List of word Dict objects, one for each word in the transcript, with the following fields:
'word': the word itself.
'start_time': the start time of the word in seconds in the corresponding audio file.
'end_time': the end time of the word in seconds in the corresponding audio file.
"""
transcripts = []
for i in range(emission_mx.size(dim=0)):
# if the batch_size is > 1, use the maximum original audio length in the batch, as all other audio files are padded to the max length during preprocessing.
audio_len = audio_lens[i] if emission_mx.size(dim=0) == 1 else max(audio_lens)
if num_hyps > 1:
all_results = []
for hyp in hypos[i]:
hyp_dict = dict()
if hyp['words']:
# 'words' field is not empty if using a lexicon.
transcript = ' '.join(hyp['words']).lower()
else:
# 'words' field is [] if lexicon-free decoding, convert from non-blank symbols to words instead.
tokens_str = ''.join(hyp['symbols'])
transcript = ' '.join(tokens_str.split('|')).strip().lower()
hyp_dict['pred_txt'] = transcript
if time_aligns:
word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hyp['symbols'], hyp['timesteps'])
timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
hyp_dict['timestamps_word'] = timestamps_word
# add a hypothesis dict
all_results.append(hyp_dict)
transcripts.append(all_results)
else:
hyp_dict = dict()
# append the decoded phrase (as a list of words) from the prediction of the first beam [0] (most likely transcript).
if hypos[i][0]['words']:
# 'words' field is not empty if using a lexicon.
transcript = ' '.join(hypos[i][0]['words']).lower()
else:
# 'words' field is [] if lexicon-free decoding, convert from non-blank symbols to words instead.
tokens_str = ''.join(hypos[i][0]['symbols'])
transcript = ' '.join(tokens_str.split('|')).strip().lower()
hyp_dict['pred_txt'] = transcript
if time_aligns:
word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hypos[i][0]['symbols'], hypos[i][0]['timesteps'])
timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
hyp_dict['timestamps_word'] = timestamps_word
# add a hypothesis dict
transcripts.append(hyp_dict)
return transcripts
transcripts = []
for i in range(emission_mx.size(dim=0)):
# if the batch_size is > 1, use the maximum original audio length in the batch, as all other audio files are padded to the max length during preprocessing.
audio_len = audio_lens[i] if emission_mx.size(dim=0) == 1 else max(audio_lens)
hyp_dict = dict()
# append the decoded phrase (as a list of words) from the prediction of the first beam [0] (most likely transcript).
transcript = ' '.join(hypos[i][0]['words']).lower()
hyp_dict['pred_txt'] = transcript
if self.time_aligns:
word_times = get_word_time_alignments_fairseq(audio_len, emission_mx.size(dim=1), 16000, hypos[i][0]['symbols'], hypos[i][0]['timesteps'])
timestamps_word = normalize_timestamp_output_w2v2(hyp_dict['pred_txt'].split(' '), word_times)
hyp_dict['timestamps_word'] = timestamps_word
# add a hypothesis dict
transcripts.append(hyp_dict)
return transcripts
Most importantly, the conversion from frame numbers to seconds happens in `get_word_time_alignments_fairseq`:
def get_word_time_alignments_fairseq(audio_len, num_frames, sample_rate, symbols, timesteps):
"""Get word time alignments information for a hypothesis transcript input by converting from timesteps to seconds.
Args:
audio_len (int):
The length of the audio file, in number of samples.
num_frames (int):
The number of frames in the ASR acoustic model emission matrix.
sample_rate (int):
The sample rate of the loaded audio file.
symbols (List[str]):
Decoded list of characters corresponding to the non-blank tokens returned by the decoder.
timesteps (List[int]):
Frame numbers corresponding to the non-blank tokens/symbols.
Returns:
word_times (List[Tuple[float, float]]):
List of tuples of (start_time, stop_time) in seconds for each word in the transcript.
"""
# list of times in seconds in the corresponding audio file for the non-blank tokens/symbols.
timestamps = []
# get the timestep in seconds corresponding to each non-blank token.
for frame_num in timesteps:
timestamp = frame_num * (audio_len / (num_frames * sample_rate))
timestamps.append(timestamp)
# NOTE: algorithm only works if the first and last symbols are '|', so add them in if that's not the case.
frame_offset = 0
if symbols[0] != '|':
symbols.insert(0, '|')
# if adding a symbol at index 0, all symbols will have their frame idx increased by 1, so an offset of -1 is created.
frame_offset = -1
if symbols[-1] != '|':
symbols.append('|')
word_boundary_idxs = [] # tuples of word start and stop indices.
# get the indices of all word-boundary tokens (|).
wb_tokens_idxs = [i for i in range(len(symbols)) if symbols[i] == '|']
# create tuples for each word that contains the indices of its start symbol and end symbol.
tup = [] # initialise the first tuple of word start character and word end character indices.
# loop through the indices of the '|' tokens and find the indices of the word-boundary symbols/characters that are the start and end characters of each word.
for wb_tokens_idx in wb_tokens_idxs:
try:
if symbols[wb_tokens_idx-1] != '|' and tup:
# there is a start index in tuple, but no end index yet.
# end index has been found.
if wb_tokens_idx-1 == tup[0]:
# word is composed of only one character, add the index of this '|' token as the end character index for the word.
tup.append(wb_tokens_idx)
else:
# word is composed of more than one character.
tup.append(wb_tokens_idx-1) # add an end character index for the word.
# add the tuple as complete word to the list of word start and end index tuples.
word_boundary_idxs.append(tup)
tup = [] # reset the tuple.
# continue onto the next if statement as this '|' token may be the boundary between two words.
if symbols[wb_tokens_idx+1] != '|':
# start character of new word reached.
tup.append(wb_tokens_idx+1) # add a start character index for the word.
except IndexError:
continue
# create tuples of start and stop times for each word
word_times = [(timestamps[start_idx + frame_offset], timestamps[end_idx + frame_offset]) for start_idx, end_idx in word_boundary_idxs]
return word_times
And `normalize_timestamp_output_w2v2` packages the words and their times into dicts:
def normalize_timestamp_output_w2v2(words, word_time_tuples):
"""Get word Dict objects with time information for each word in the hypothesis transcript.
Args:
words (List[str]):
List of words in the transcript.
word_time_tuples (List[Tuple[float,float]]):
List of tuples of (start_time, stop_time) in seconds for each word in the transcript.
Returns:
values (List[Dict]):
List of dict objects where each dict has the following fields:
'word': (str) the word itself.
'start_time': (float) the start time in seconds of the word in the corresponding audio file.
'end_time': (float) the end time in seconds of the word in the corresponding audio file.
"""
values = []
for word, (word_start, word_end) in zip(words, word_time_tuples):
vals_dict = dict()
vals_dict['word'] = word
vals_dict['start_time'] = word_start
vals_dict['end_time'] = word_end
values.append(vals_dict)
return values
The formula I use to calculate the time in seconds in the corresponding audio for each non-blank symbol in the transcript is the following:
timestamp = frame_num * (audio_len / (num_frames * sample_rate))
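For a rough sanity check, here is a small illustrative calculation (the audio length, frame count, and frame numbers below are made-up values, not taken from the PR):
# Illustrative numbers only: a 5-second, 16 kHz clip has audio_len = 80000 samples.
# If the emission matrix has num_frames = 249 frames, each frame covers
# audio_len / (num_frames * sample_rate) ~= 0.0201 seconds of audio.
audio_len, num_frames, sample_rate = 80000, 249, 16000
seconds_per_frame = audio_len / (num_frames * sample_rate)
for frame_num in (10, 100, 240):  # made-up frame indices of non-blank symbols
    print(frame_num, round(frame_num * seconds_per_frame, 3))
# prints approximately: 10 0.201, 100 2.008, 240 4.819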
What does this PR do?
Fixes #3371.
Currently, the output from Wav2Vec 2.0 decoding does not contain word-level start/end times, which can be useful for certain applications of ASR. Based on the discussion here, they could be computed based on the output from the Flashlight decoder. For the KenLM decoder, we could first obtain the frame number corresponding to each non-blank token. Next, the timestamp of each character could be computed as `segment_start + frame_no/total_frames * segment_duration`. Finally, the start and end time of each word could be calculated based on the timestamps of the word boundary characters. In order to enable this, the frame number of each non-blank character is returned as a result of KenLM decoding. This is similar to the `timesteps` output from the ctcdecode library.
PR review
@alexeib